import sys
import torch
import torch_npu
from   common    import AclInfer

# Smoke-test the ResNet-50 offline model (.om) on NPU device 0: print the
# first input/output descriptors, run a random input through the model, and
# report the output's shape, dtype and device.
mdl = AclInfer(model_path="../model/resnet50.om", device_id=0)
mdl.init()

# Everything after a successful init() is guarded so mdl.deinit() always runs,
# even if descriptor lookup, tensor creation or inference raises.
try:
    input_desc = mdl.getInputDescByIndex(0)
    output_desc = mdl.getOutputDescByIndex(0)
    print(input_desc)
    print(output_desc)

    # Random float32 tensor of shape (1, 3, 224, 224), moved to the NPU.
    # NOTE(review): the * 1000 scaling presumably produces large-magnitude
    # values on purpose (e.g. to exercise check_inputs) — confirm intent.
    # Named input_tensor to avoid shadowing the builtin input().
    input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.float32).npu() * 1000

    # loops/warmup presumably control repeated (timed) runs inside
    # AclInfer.__call__ — TODO confirm against its implementation.
    output = mdl(input_tensor, check_inputs=True, loops=500, warmup=20)

    print(output.shape, output.dtype, output.device)
finally:
    mdl.deinit()